// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package validation

import (
	"fmt"
	"os"
	"path/filepath"
	"slices"

	"github.com/samber/lo"

	v1 "github.com/fatedier/frp/pkg/config/v1"
	"github.com/fatedier/frp/pkg/policy/featuregate"
	"github.com/fatedier/frp/pkg/policy/security"
)

// ValidateClientCommonConfig validates every section of the client common
// configuration and aggregates the results.
//
// The sub-validators run in a fixed order: feature gates, auth, log, web
// server, transport, include files. Every validator runs regardless of earlier
// failures. Warnings and errors are accumulated separately via AppendError.
func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) {
	var (
		allWarnings Warning
		allErrs     error
	)

	// collect folds one validator's (warning, error) result into the totals.
	// Validators that only return an error are passed a nil warning.
	collect := func(w Warning, err error) {
		allWarnings = AppendError(allWarnings, w)
		allErrs = AppendError(allErrs, err)
	}

	collect(validateFeatureGates(c))
	collect(validateAuthConfig(&c.Auth, unsafeFeatures))
	collect(nil, validateLogConfig(&c.Log))
	collect(nil, validateWebServerConfig(&c.WebServer))
	collect(validateTransportConfig(&c.Transport))
	collect(validateIncludeFiles(c.IncludeConfigFiles))

	return allWarnings, allErrs
}

// validateFeatureGates rejects configuration that relies on a feature gate
// which is not enabled. Currently this covers only VirtualNet: setting
// virtualNet.address requires the VirtualNet gate. It never returns warnings.
func validateFeatureGates(c *v1.ClientCommonConfig) (Warning, error) {
	// Nothing to check when virtual networking is not configured, or when it
	// is configured and its gate is on.
	if c.VirtualNet.Address == "" || featuregate.Enabled(featuregate.VirtualNet) {
		return nil, nil
	}
	return nil, fmt.Errorf("VirtualNet feature is not enabled; enable it by setting the appropriate feature gate flag")
}

// validateAuthConfig validates the client auth section and collects every
// problem it finds:
//   - the method must be one of SupportedAuthMethods;
//   - every additional scope must be in SupportedAuthAdditionalScopes;
//   - auth.token and auth.tokenSource must not both be set;
//   - an "exec" token source requires the TokenSourceExec unsafe feature;
//   - the token source, if present, must pass its own Validate check;
//   - the OIDC sub-section must pass validateOIDCConfig.
//
// It never returns warnings.
func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) {
	var result error

	if methodSupported := slices.Contains(SupportedAuthMethods, c.Method); !methodSupported {
		result = AppendError(result, fmt.Errorf("invalid auth method, optional values are %v", SupportedAuthMethods))
	}
	if scopesSupported := lo.Every(SupportedAuthAdditionalScopes, c.AdditionalScopes); !scopesSupported {
		result = AppendError(result, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
	}

	hasTokenSource := c.TokenSource != nil

	// A static token and a token source are alternatives; only one may be set.
	if hasTokenSource && c.Token != "" {
		result = AppendError(result, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
	}

	if hasTokenSource {
		// Running an external command to obtain the token is gated behind an
		// explicit opt-in.
		if c.TokenSource.Type == "exec" && !unsafeFeatures.IsEnabled(security.TokenSourceExec) {
			result = AppendError(result, fmt.Errorf("unsafe feature %q is not enabled. "+
				"To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec))
		}
		if srcErr := c.TokenSource.Validate(); srcErr != nil {
			result = AppendError(result, fmt.Errorf("invalid auth.tokenSource: %v", srcErr))
		}
	}

	if oidcErr := validateOIDCConfig(&c.OIDC, unsafeFeatures); oidcErr != nil {
		result = AppendError(result, oidcErr)
	}
	return nil, result
}

// validateOIDCConfig validates auth.oidc when auth.oidc.tokenSource is set; it
// returns nil immediately otherwise. With a token source configured, it reports:
//   - any other auth.oidc field that is also set (they are mutually exclusive);
//   - an "exec" token source when the TokenSourceExec unsafe feature is off;
//   - an error from the token source's own Validate method.
//
// All problems are collected into a single returned error.
func validateOIDCConfig(c *v1.AuthOIDCClientConfig, unsafeFeatures *security.UnsafeFeatures) error {
	src := c.TokenSource
	if src == nil {
		return nil
	}

	otherFieldSet := c.ClientID != "" ||
		c.ClientSecret != "" ||
		c.Audience != "" ||
		c.Scope != "" ||
		c.TokenEndpointURL != "" ||
		len(c.AdditionalEndpointParams) > 0 ||
		c.TrustedCaFile != "" ||
		c.InsecureSkipVerify ||
		c.ProxyURL != ""

	var result error
	if otherFieldSet {
		result = AppendError(result, fmt.Errorf("cannot specify both auth.oidc.tokenSource and any other field of auth.oidc"))
	}
	if src.Type == "exec" && !unsafeFeatures.IsEnabled(security.TokenSourceExec) {
		result = AppendError(result, fmt.Errorf("unsafe feature %q is not enabled. "+
			"To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec))
	}
	if srcErr := src.Validate(); srcErr != nil {
		result = AppendError(result, fmt.Errorf("invalid auth.oidc.tokenSource: %v", srcErr))
	}
	return result
}

// validateTransportConfig validates the client transport section.
//
// Errors:
//   - heartbeatTimeout smaller than heartbeatInterval, checked only when both
//     values are positive;
//   - a protocol that is not in SupportedTransportProtocols.
//
// Warnings: when transport.tls.enable is unset or false, each of certFile,
// keyFile and trustedCaFile that is non-empty yields a warning.
func validateTransportConfig(c *v1.ClientTransportConfig) (Warning, error) {
	var (
		warns Warning
		errs  error
	)

	heartbeatConfigured := c.HeartbeatTimeout > 0 && c.HeartbeatInterval > 0
	if heartbeatConfigured && c.HeartbeatTimeout < c.HeartbeatInterval {
		errs = AppendError(errs, fmt.Errorf("invalid transport.heartbeatTimeout, heartbeat timeout should not less than heartbeat interval"))
	}

	if tlsEnabled := lo.FromPtr(c.TLS.Enable); !tlsEnabled {
		tlsOnlyFields := []struct {
			name  string
			value string
		}{
			{"transport.tls.certFile", c.TLS.CertFile},
			{"transport.tls.keyFile", c.TLS.KeyFile},
			{"transport.tls.trustedCaFile", c.TLS.TrustedCaFile},
		}
		for _, field := range tlsOnlyFields {
			// w stays nil for empty fields; it is appended either way.
			var w Warning
			if field.value != "" {
				w = fmt.Errorf("%s is invalid when transport.tls.enable is false", field.name)
			}
			warns = AppendError(warns, w)
		}
	}

	if protocolSupported := slices.Contains(SupportedTransportProtocols, c.Protocol); !protocolSupported {
		errs = AppendError(errs, fmt.Errorf("invalid transport.protocol, optional values are %v", SupportedTransportProtocols))
	}
	return warns, errs
}

// validateIncludeFiles checks that the parent directory of every include
// entry exists and is a directory. Each entry may be a file path or a glob
// pattern; only filepath.Dir of the entry is inspected.
//
// It collects one error per problematic entry:
//   - the parent path cannot be made absolute;
//   - the parent path does not exist;
//   - the parent path cannot be stat'ed for another reason (e.g. permissions);
//   - the parent path exists but is not a directory.
//
// It never returns warnings.
func validateIncludeFiles(files []string) (Warning, error) {
	var errs error
	for _, f := range files {
		absDir, err := filepath.Abs(filepath.Dir(f))
		if err != nil {
			errs = AppendError(errs, fmt.Errorf("include: parse directory of %s failed: %w", f, err))
			continue
		}

		info, err := os.Stat(absDir)
		switch {
		case os.IsNotExist(err):
			errs = AppendError(errs, fmt.Errorf("include: directory of %s not exist", f))
		case err != nil:
			// Previously ignored: an unreadable parent directory would pass
			// validation and only fail later when the files are loaded.
			errs = AppendError(errs, fmt.Errorf("include: stat directory of %s failed: %w", f, err))
		case !info.IsDir():
			errs = AppendError(errs, fmt.Errorf("include: parent path %s of %s is not a directory", absDir, f))
		}
	}
	return nil, errs
}

// ValidateAllClientConfig validates a complete client configuration: the
// common config (skipped when c is nil), then each proxy config, then each
// visitor config.
//
// Validation stops at the first stage that fails, in this order:
//   - an error from ValidateClientCommonConfig;
//   - the first failing proxy;
//   - the first failing visitor.
//
// Warnings from the common config are returned in every case. Proxy and visitor
// errors are prefixed with the offending config's name and wrap the underlying
// error, so callers can use errors.Is / errors.As on it.
func ValidateAllClientConfig(
	c *v1.ClientCommonConfig,
	proxyCfgs []v1.ProxyConfigurer,
	visitorCfgs []v1.VisitorConfigurer,
	unsafeFeatures *security.UnsafeFeatures,
) (Warning, error) {
	var warnings Warning
	if c != nil {
		warning, err := ValidateClientCommonConfig(c, unsafeFeatures)
		warnings = AppendError(warnings, warning)
		if err != nil {
			return warnings, err
		}
	}

	// Distinct loop variable names avoid shadowing the common config c.
	for _, pc := range proxyCfgs {
		if err := ValidateProxyConfigurerForClient(pc); err != nil {
			return warnings, fmt.Errorf("proxy %s: %w", pc.GetBaseConfig().Name, err)
		}
	}

	for _, vc := range visitorCfgs {
		if err := ValidateVisitorConfigurer(vc); err != nil {
			return warnings, fmt.Errorf("visitor %s: %w", vc.GetBaseConfig().Name, err)
		}
	}
	return warnings, nil
}
